/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file matmul_shape_info.h
 * \brief matmul shape info manager
 */

#ifndef IMPL_MATMUL_MODULES_PARAM_MATMUL_SHAPE_INFO_H
#define IMPL_MATMUL_MODULES_PARAM_MATMUL_SHAPE_INFO_H

#include "../matmul_module.h"

namespace AscendC {
namespace Impl {
namespace Detail {
template <typename IMPL, typename A_TYPE, const auto &MM_CFG>
class MatmulShapeInfo {
    MATMUL_USE_MODULE(MLoop);
    MATMUL_USE_MODULE(NLoop);
    MATMUL_USE_MODULE(KLoop);
public:
    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline bool IsTransposeA() const
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.isTransposeA;
        } else {
            return MATMUL_CONST_PARAM_VAR.isTransposeA_;
        }
    }

    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline bool IsTransposeB() const
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.isTransposeB;
        } else {
            return MATMUL_CONST_PARAM_VAR.isTransposeB_;
        }
    }

    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline uint32_t GetOrgM()
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.M;
        } else {
            return MATMUL_CAST_TO_IMPL()->M_;
        }
    }

    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline uint32_t GetOrgN()
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.N;
        } else {
            return MATMUL_CAST_TO_IMPL()->N_;
        }
    }

    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline uint32_t GetOrgKa()
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.Ka;
        } else {
            return MATMUL_CAST_TO_IMPL()->Ka_;
        }
    }

    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline uint32_t GetOrgKb()
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.Kb;
        } else {
            return MATMUL_CAST_TO_IMPL()->Kb_;
        }
    }

    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline uint32_t GetOrgKc()
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.Kc;
        } else {
            return MATMUL_CAST_TO_IMPL()->Kc_;
        }
    }

    template <bool IS_INTRA_BLOCK = false, bool IS_BASIC = false>
    __aicore__ inline int32_t GetSingleCoreM() const
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.singleCoreM;
        } else if constexpr (IS_BASIC) {
            return ToMatmulConfig(MM_CFG).singleCoreM;
        } else {
            return MATMUL_CONST_PARAM_VAR.singleCoreM_;
        }
    }

    template <bool IS_INTRA_BLOCK = false, bool IS_BASIC = false>
    __aicore__ inline int32_t GetSingleCoreN() const
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.singleCoreN;
        } else if constexpr (IS_BASIC) {
            return ToMatmulConfig(MM_CFG).singleCoreN;
        } else {
            return MATMUL_CONST_PARAM_VAR.singleCoreN_;
        }
    }

    template <bool IS_INTRA_BLOCK = false, bool IS_BASIC = false>
    __aicore__ inline int32_t GetSingleCoreK() const
    {
        if constexpr (IS_INTRA_BLOCK) {
            return MATMUL_CONST_INTRA_BLOCK.singleCoreK;
        } else if constexpr (IS_BASIC) {
            return ToMatmulConfig(MM_CFG).singleCoreK;
        } else {
            return MATMUL_CONST_PARAM_VAR.singleCoreK_;
        }
    }

    __aicore__ inline uint32_t GetMIter()
    {
        if constexpr (isNormEnableScheduler<A_TYPE, MM_CFG> || IsBmmEnableScheduler<A_TYPE, MM_CFG> ||
                      IsBasicBlockEnable<MM_CFG> || DoMatmulIBShareNorm(MM_CFG) || IsIntrablock<MM_CFG>) {
            return MATMUL_MODULE(MLoop)->GetTotalIter();
        } else {
            return MATMUL_CONST_PARAM_VAR.mIter_;
        }
    }

    __aicore__ inline uint32_t GetNIter()
    {
        if constexpr (isNormEnableScheduler<A_TYPE, MM_CFG> || IsBmmEnableScheduler<A_TYPE, MM_CFG> ||
                      IsBasicBlockEnable<MM_CFG> || DoMatmulIBShareNorm(MM_CFG) || IsIntrablock<MM_CFG>) {
            return MATMUL_MODULE(NLoop)->GetTotalIter();
        } else {
            return MATMUL_CONST_PARAM_VAR.nIter_;
        }
    }

    __aicore__ inline uint32_t GetKIter()
    {
        if constexpr (isNormEnableScheduler<A_TYPE, MM_CFG> || IsBmmEnableScheduler<A_TYPE, MM_CFG> ||
                      IsBasicBlockEnable<MM_CFG> || DoMatmulIBShareNorm(MM_CFG) || IsIntrablock<MM_CFG>) {
            return MATMUL_MODULE(KLoop)->GetTotalIter();
        } else {
            return MATMUL_CONST_PARAM_VAR.kIter_;
        }
    }
};
}  // namespace Detail
}  // namespace Impl
}  // namespace AscendC
#endif // IMPL_MATMUL_MODULES_PARAM_MATMUL_SHAPE_INFO_H
